{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
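    "# ONNX Script Model-Local Functions\n",
    "A model in ONNX may contain model-local functions. When an onnxscript function is converted to a ModelProto, the default behavior is to include the definitions of all transitively called function-ops as model-local functions in the generated model, provided their onnxscript function definitions have been seen. Callers can override this behavior by explicitly providing the list of FunctionProtos to include in the generated model.\n",
    "\n",
    "First, let us define an ONNX Script function that calls other ONNX Script functions."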
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "\n",
    "from onnxscript import FLOAT, script\n",
    "from onnxscript import opset15 as op\n",
    "from onnxscript.values import Opset\n",
    "\n",
    "# A dummy opset used for model-local functions\n",
    "local = Opset(\"local\", 1)\n",
    "\n",
    "\n",
    "@script(local, default_opset=op)\n",
    "def diff_square(x, y):\n",
    "    diff = x - y\n",
    "    return diff * diff\n",
    "\n",
    "\n",
    "# Note: this name shadows Python's built-in sum() for the rest of the notebook.\n",
    "@script(local)\n",
    "def sum(z):\n",
    "    return op.ReduceSum(z, keepdims=1)\n",
    "\n",
    "\n",
    "@script()\n",
    "def l2norm(x: FLOAT[\"N\"], y: FLOAT[\"N\"]) -> FLOAT[1]:  # noqa: F821\n",
    "    return op.Sqrt(sum(diff_square(x, y)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us see what the generated model looks like by default:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<\n",
      "   ir_version: 8,\n",
      "   opset_import: [\"local\" : 1, \"\" : 15]\n",
      ">\n",
      "l2norm (float[N] x, float[N] y) => (float[1] return_val) {\n",
      "   tmp = local.diff_square (x, y)\n",
      "   tmp_0 = local.sum (tmp)\n",
      "   return_val = Sqrt (tmp_0)\n",
      "}\n",
      "<\n",
      "  domain: \"local\",\n",
      "  opset_import: [\"\" : 15]\n",
      ">\n",
      "sum (z) => (return_val)\n",
      "{\n",
      "   return_val = ReduceSum <keepdims: int = 1> (z)\n",
      "}\n",
      "<\n",
      "  domain: \"local\",\n",
      "  opset_import: [\"\" : 15]\n",
      ">\n",
      "diff_square (x, y) => (return_val)\n",
      "{\n",
      "   diff = Sub (x, y)\n",
      "   return_val = Mul (diff, diff)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "model = l2norm.to_model_proto()\n",
    "print(onnx.printer.to_text(model))"
   ]
  },
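  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed text above can also be checked programmatically. The following is a minimal sketch that assumes the `model` variable from the previous cell is still in scope. It reads `ModelProto.functions`, the repeated field that holds the model-local functions, and lists the domain and name of each entry."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumes `model` was built by the previous cell.\n",
    "# Each entry of model.functions is a FunctionProto embedded in the model.\n",
    "for f in model.functions:\n",
    "    print(f.domain, f.name)"
   ]
  },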
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let us explicitly specify which functions to include. First, generate a model with no model-local functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<\n",
      "   ir_version: 8,\n",
      "   opset_import: [\"local\" : 1, \"\" : 15]\n",
      ">\n",
      "l2norm (float[N] x, float[N] y) => (float[1] return_val) {\n",
      "   tmp = local.diff_square (x, y)\n",
      "   tmp_0 = local.sum (tmp)\n",
      "   return_val = Sqrt (tmp_0)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "model = l2norm.to_model_proto(functions=[])\n",
    "print(onnx.printer.to_text(model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, generate a model that includes a single model-local function (`sum`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<\n",
      "   ir_version: 8,\n",
      "   opset_import: [\"local\" : 1, \"\" : 15]\n",
      ">\n",
      "l2norm (float[N] x, float[N] y) => (float[1] return_val) {\n",
      "   tmp = local.diff_square (x, y)\n",
      "   tmp_0 = local.sum (tmp)\n",
      "   return_val = Sqrt (tmp_0)\n",
      "}\n",
      "<\n",
      "  domain: \"local\",\n",
      "  opset_import: [\"\" : 15]\n",
      ">\n",
      "sum (z) => (return_val)\n",
      "{\n",
      "   return_val = ReduceSum <keepdims: int = 1> (z)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "model = l2norm.to_model_proto(functions=[sum])\n",
    "print(onnx.printer.to_text(model))"
   ]
  },
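  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above passes the onnxscript function `sum` directly. The introduction describes the `functions` argument as a list of FunctionProtos, so the selection can presumably also be made with an explicit FunctionProto. The following is a sketch under one assumption: that `to_function_proto()` is available on script-decorated functions in the installed onnxscript version. The name `diff_square_proto` is introduced here for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumption: script-decorated functions expose to_function_proto().\n",
    "diff_square_proto = diff_square.to_function_proto()\n",
    "\n",
    "# Include only diff_square as a model-local function.\n",
    "model = l2norm.to_model_proto(functions=[diff_square_proto])\n",
    "\n",
    "# List the names of the model-local functions that were embedded.\n",
    "print([f.name for f in model.functions])"
   ]
  },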
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
